{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Stratification\n",
    "\n",
    "In the previous notebooks, we always used either a default `KFold` or a\n",
    "`ShuffleSplit` cross-validation strategies to iteratively split our dataset.\n",
    "However, you should not assume that these approaches are always the best\n",
    "option: some other cross-validation strategies might be better adapted to your\n",
    "problem.\n",
    "\n",
    "Let's start with the concept of stratification by giving an example where we\n",
    "can get into trouble if we are not careful. Let's load the iris dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "\n",
    "data, target = load_iris(as_frame=True, return_X_y=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, we create a basic machine-learning model: a logistic\n",
    "regression. We expect this model to work quite well on the iris dataset since\n",
    "this is a toy dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "model = make_pipeline(StandardScaler(), LogisticRegression())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the model is created, we can evaluate it using cross-validation. We start\n",
    "by using the `KFold` strategy.\n",
    "\n",
    "Let's review how this strategy works. For such purpose, we define a dataset\n",
    "with nine samples and split the dataset into three folds (i.e. `n_splits=3`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "data_random = np.random.randn(9, 1)\n",
    "cv = KFold(n_splits=3)\n",
    "for train_index, test_index in cv.split(data_random):\n",
    "    print(\"TRAIN:\", train_index, \"TEST:\", test_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By defining three splits, we use three samples (1-fold) for testing and six\n",
    "(2-folds) for training each time. `KFold` does not shuffle by default. It\n",
    "means that the three first samples are selected for the testing set at the\n",
    "first split, then the three next three samples for the second split, and the\n",
    "three next for the last split. In the end, all samples have been used in\n",
    "testing at least once among the different splits.\n",
    "\n",
    "Now, let's apply this strategy to check the generalization performance of our\n",
    "model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import cross_validate\n",
    "\n",
    "cv = KFold(n_splits=3)\n",
    "results = cross_validate(model, data, target, cv=cv)\n",
    "test_score = results[\"test_score\"]\n",
    "print(\n",
    "    f\"The average accuracy is {test_score.mean():.3f} \u00b1 {test_score.std():.3f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is a real surprise that our model cannot correctly classify any sample in\n",
    "any cross-validation split. We now check our target's value to understand the\n",
    "issue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "target.plot()\n",
    "plt.xlabel(\"Sample index\")\n",
    "plt.ylabel(\"Class\")\n",
    "plt.yticks(target.unique())\n",
    "_ = plt.title(\"Class value in target y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the target vector `target` is ordered. This has some unexpected\n",
    "consequences when using the `KFold` cross-validation. To illustrate the\n",
    "consequences, we show the class count in each fold of the cross-validation in\n",
    "the train and test set.\n",
    "\n",
    "Let's compute the class counts for both the training and testing sets using\n",
    "the `KFold` cross-validation, and plot these information in a bar plot.\n",
    "\n",
    "We iterate given the number of split and check how many samples of each are\n",
    "present in the training and testing set. We then store the information into\n",
    "two distinct lists; one for the training set and one for the testing set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "n_splits = 3\n",
    "cv = KFold(n_splits=n_splits)\n",
    "\n",
    "train_cv_counts = []\n",
    "test_cv_counts = []\n",
    "for fold_idx, (train_idx, test_idx) in enumerate(cv.split(data, target)):\n",
    "    target_train, target_test = target.iloc[train_idx], target.iloc[test_idx]\n",
    "\n",
    "    train_cv_counts.append(target_train.value_counts())\n",
    "    test_cv_counts.append(target_test.value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To plot the information on a single figure, we concatenate the information\n",
    "regarding the fold within the same dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts = pd.concat(\n",
    "    train_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "train_cv_counts.index.name = \"Class label\"\n",
    "train_cv_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_cv_counts = pd.concat(\n",
    "    test_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "test_cv_counts.index.name = \"Class label\"\n",
    "test_cv_counts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can represent graphically this information with bar plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Training set class counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Test set class counts\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can confirm that in each fold, only two of the three classes are present in\n",
    "the training set and all samples of the remaining class is used as a test set.\n",
    "So our model is unable to predict this class that was unseen during the\n",
    "training stage.\n",
    "\n",
    "One possibility to solve the issue is to shuffle the data before splitting the\n",
    "data into three groups."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv = KFold(n_splits=3, shuffle=True, random_state=0)\n",
    "results = cross_validate(model, data, target, cv=cv)\n",
    "test_score = results[\"test_score\"]\n",
    "print(\n",
    "    f\"The average accuracy is {test_score.mean():.3f} \u00b1 {test_score.std():.3f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get results that are closer to what we would expect with an accuracy above\n",
    "90%. Now that we solved our first issue, it would be interesting to check if\n",
    "the class frequency in the training and testing set is equal to our original\n",
    "set's class frequency. It would ensure that we are training and testing our\n",
    "model with a class distribution that we would encounter in production."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts = []\n",
    "test_cv_counts = []\n",
    "for fold_idx, (train_idx, test_idx) in enumerate(cv.split(data, target)):\n",
    "    target_train, target_test = target.iloc[train_idx], target.iloc[test_idx]\n",
    "\n",
    "    train_cv_counts.append(target_train.value_counts())\n",
    "    test_cv_counts.append(target_test.value_counts())\n",
    "train_cv_counts = pd.concat(\n",
    "    train_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "test_cv_counts = pd.concat(\n",
    "    test_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "train_cv_counts.index.name = \"Class label\"\n",
    "test_cv_counts.index.name = \"Class label\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Training set class counts\\n(with suffling)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Test set class counts\\n(with suffling)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that neither the training and testing sets have the same class\n",
    "frequencies as our original dataset because the count for each class is\n",
    "varying a little.\n",
    "\n",
    "However, one might want to split our data by preserving the original class\n",
    "frequencies: we want to **stratify** our data by class. In scikit-learn, some\n",
    "cross-validation strategies implement the stratification; they contain\n",
    "`Stratified` in their names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "cv = StratifiedKFold(n_splits=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = cross_validate(model, data, target, cv=cv)\n",
    "test_score = results[\"test_score\"]\n",
    "print(\n",
    "    f\"The average accuracy is {test_score.mean():.3f} \u00b1 {test_score.std():.3f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts = []\n",
    "test_cv_counts = []\n",
    "for fold_idx, (train_idx, test_idx) in enumerate(cv.split(data, target)):\n",
    "    target_train, target_test = target.iloc[train_idx], target.iloc[test_idx]\n",
    "\n",
    "    train_cv_counts.append(target_train.value_counts())\n",
    "    test_cv_counts.append(target_test.value_counts())\n",
    "train_cv_counts = pd.concat(\n",
    "    train_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "test_cv_counts = pd.concat(\n",
    "    test_cv_counts, axis=1, keys=[f\"Fold #{idx}\" for idx in range(n_splits)]\n",
    ")\n",
    "train_cv_counts.index.name = \"Class label\"\n",
    "test_cv_counts.index.name = \"Class label\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Training set class counts\\n(with stratifying)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_cv_counts.plot.bar()\n",
    "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n",
    "plt.ylabel(\"Count\")\n",
    "_ = plt.title(\"Test set class counts\\n(with stratifying)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, we observe that the class counts are very close both in the\n",
    "train set and the test set. The difference is due to the small number of\n",
    "samples in the iris dataset.\n",
    "\n",
    "In other words, stratifying is more effective than just shuffling when it\n",
    "comes to making sure that the distributions of classes in all the folds are\n",
    "representative of the entire dataset. As training and testing folds have\n",
    "similar class distributions, stratifying leads to a more realistic measure of\n",
    "the model\u2019s ability to generalize. This is specially important when the\n",
    "performance metrics depend on the proportion of the positive class, as we will\n",
    "see in a future notebook.\n",
    "\n",
    "The interested reader can learn about other stratified cross-validation\n",
    "techniques in the [scikit-learn user\n",
    "guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-with-stratification-based-on-class-labels)."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}